#include "natsocketfactory.h"

#include "logging.h"
#include "natserver.h"
#include "virtualsocketserver.h"

namespace base {

	// Packs the given socketaddress into the buffer in buf, in the quasi-STUN
	// format that the natserver uses.
	// Returns 0 if an invalid address is passed.
	size_t PackAddressForNAT(char* buf, size_t buf_size,
		const SocketAddress& remote_addr) {
			const IPAddress& ip = remote_addr.ipaddr();
			int family = ip.family();
			buf[0] = 0;
			buf[1] = family;
			// Writes the port.
			*(reinterpret_cast<uint16*>(&buf[2])) = HostToNetwork16(remote_addr.port());
			if (family == AF_INET) {
				ASSERT(buf_size >= kNATEncodedIPv4AddressSize);
				in_addr v4addr = ip.ipv4_address();
				std::memcpy(&buf[4], &v4addr, kNATEncodedIPv4AddressSize - 4);
				return kNATEncodedIPv4AddressSize;
			} else if (family == AF_INET6) {
				ASSERT(buf_size >= kNATEncodedIPv6AddressSize);
				in6_addr v6addr = ip.ipv6_address();
				std::memcpy(&buf[4], &v6addr, kNATEncodedIPv6AddressSize - 4);
				return kNATEncodedIPv6AddressSize;
			}
			return 0U;
	}

	// Decodes the remote address from a packet that has been encoded with the nat's
	// quasi-STUN format. Returns the length of the address (i.e., the offset into
	// data where the original packet starts).
	size_t UnpackAddressFromNAT(const char* buf, size_t buf_size,
		SocketAddress* remote_addr) {
			ASSERT(buf_size >= 8);
			ASSERT(buf[0] == 0);
			int family = buf[1];
			uint16 port = NetworkToHost16(*(reinterpret_cast<const uint16*>(&buf[2])));
			if (family == AF_INET) {
				const in_addr* v4addr = reinterpret_cast<const in_addr*>(&buf[4]);
				*remote_addr = SocketAddress(IPAddress(*v4addr), port);
				return kNATEncodedIPv4AddressSize;
			} else if (family == AF_INET6) {
				ASSERT(buf_size >= 20);
				const in6_addr* v6addr = reinterpret_cast<const in6_addr*>(&buf[4]);
				*remote_addr = SocketAddress(IPAddress(*v6addr), port);
				return kNATEncodedIPv6AddressSize;
			}
			return 0U;
	}


	// NATSocket
	// AsyncSocket handed out by NATSocketFactory and NATSocketServer. Bind()
	// asks the NATInternalSocketFactory for the real (internal) socket and for
	// the address of the NAT server, which is stored in server_addr_.
	//  - If server_addr_ is nil, calls are forwarded to the internal socket.
	//  - Otherwise, datagrams are prefixed with the quasi-STUN address header
	//    (PackAddressForNAT / UnpackAddressFromNAT) and exchanged with the NAT
	//    server. A SOCK_STREAM connect goes to the NAT server, which is then sent
	//    the real destination and answers with a one-byte result code. Stream
	//    data itself is forwarded unmodified.
	class NATSocket : public AsyncSocket, public sigslot::has_slots<> {
	public:
		explicit NATSocket(NATInternalSocketFactory* sf, int family, int type)
			: sf_(sf), family_(family), type_(type), async_(true), connected_(false),
			socket_(NULL), buf_(NULL), size_(0), error_(0) {
		}

		virtual ~NATSocket() {
			delete socket_;
			delete[] buf_;
		}

		virtual SocketAddress GetLocalAddress() const {
			return (socket_) ? socket_->GetLocalAddress() : SocketAddress();
		}

		virtual SocketAddress GetRemoteAddress() const {
			return remote_addr_;  // will be NIL if not connected
		}

		// Creates the internal socket through sf_ and binds it to addr. Returns -1
		// if already bound or if no internal socket could be created. Otherwise it
		// returns the internal socket's Bind() result.
		virtual int Bind(const SocketAddress& addr) {
			if (socket_) {  // already bound, bubble up error
				return -1;
			}

			// CreateInternalSocket may leave *nat_addr untouched (NATSocketServer
			// does so for addresses not behind a translator). Reset it first so a
			// server address from an earlier Bind()/Close() cycle cannot survive
			// and wrongly route traffic through a NAT.
			server_addr_.Clear();
			socket_ = sf_->CreateInternalSocket(family_, type_, addr, &server_addr_);
			int result = (socket_) ? socket_->Bind(addr) : -1;
			if (result >= 0) {
				socket_->SignalConnectEvent.connect(this, &NATSocket::OnConnectEvent);
				socket_->SignalReadEvent.connect(this, &NATSocket::OnReadEvent);
				socket_->SignalWriteEvent.connect(this, &NATSocket::OnWriteEvent);
				socket_->SignalCloseEvent.connect(this, &NATSocket::OnCloseEvent);
			} else {
				server_addr_.Clear();
				delete socket_;
				socket_ = NULL;
			}

			return result;
		}

		// For SOCK_STREAM, connects the internal socket to the NAT server, or
		// directly to addr when not NATed. For other types, it only records addr
		// and marks the socket connected. The socket must already be bound.
		virtual int Connect(const SocketAddress& addr) {
			if (!socket_) {  // socket must be bound, for now
				return -1;
			}

			int result = 0;
			if (type_ == SOCK_STREAM) {
				result = socket_->Connect(server_addr_.IsNil() ? addr : server_addr_);
			} else {
				connected_ = true;
			}

			if (result >= 0) {
				remote_addr_ = addr;
			}

			return result;
		}

		virtual int Send(const void* data, size_t size) {
			ASSERT(connected_);
			return SendTo(data, size, remote_addr_);
		}

		// Sends to addr. When NATed and not SOCK_STREAM, the payload is prefixed
		// with addr's encoded header and sent to the NAT server instead. The
		// header length is subtracted from the returned byte count.
		virtual int SendTo(const void* data, size_t size, const SocketAddress& addr) {
			ASSERT(!connected_ || addr == remote_addr_);
			if (!socket_) {  // must be bound before sending
				return -1;
			}
			if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
				return socket_->SendTo(data, size, addr);
			}
			// This array will be too large for IPv4 packets, but only by 12 bytes.
			const size_t buf_size = size + kNATEncodedIPv6AddressSize;
			scoped_array<char> buf(new char[buf_size]);
			const size_t addrlength = PackAddressForNAT(buf.get(), buf_size, addr);
			if (addrlength == 0) {
				// addr has no encodable IPv4/IPv6 address. Sending the bare payload
				// would drop the header the NAT-side framing relies on.
				LOG(LS_ERROR) << "Cannot encode destination address for NAT: "
					<< addr.ToString();
				return -1;
			}
			const size_t encoded_size = size + addrlength;
			std::memcpy(buf.get() + addrlength, data, size);
			int result = socket_->SendTo(buf.get(), encoded_size, server_addr_);
			if (result >= 0) {
				ASSERT(result == static_cast<int>(encoded_size));
				result = result - static_cast<int>(addrlength);
			}
			return result;
		}

		virtual int Recv(void* data, size_t size) {
			SocketAddress addr;
			return RecvFrom(data, size, &addr);
		}

		// Receives one packet. When NATed and not SOCK_STREAM, it strips the
		// address header the NAT server prepends and reports the decoded sender
		// in *out_addr. It copies at most size payload bytes into data; a longer
		// payload is truncated. Returns the number of payload bytes copied, 0 if
		// the packet was dropped, or a negative value on error.
		virtual int RecvFrom(void* data, size_t size, SocketAddress *out_addr) {
			if (!socket_) {  // must be bound before receiving
				return -1;
			}
			if (server_addr_.IsNil() || type_ == SOCK_STREAM) {
				return socket_->RecvFrom(data, size, out_addr);
			}
			// Read at most the requested amount plus the largest possible header.
			// Passing read_size rather than size_ matters: buf_ may be larger from
			// an earlier call, and reading that much could yield a payload far
			// bigger than data can hold.
			const size_t read_size = size + kNATEncodedIPv6AddressSize;
			Grow(read_size);

			// Read the packet from the socket.
			SocketAddress remote_addr;
			int result = socket_->RecvFrom(buf_, read_size, &remote_addr);
			if (result < 0) {
				return result;
			}
			ASSERT(remote_addr == server_addr_);

			// TODO: we need better framing so we know how many bytes we can
			// return before we need to read the next address. For UDP, this will be
			// fine as long as the reader always reads everything in the packet.
			const size_t received = static_cast<size_t>(result);
			ASSERT(received < read_size);

			// Decode the wire packet into the actual results. A zero length
			// (unsupported family) or a header longer than what was read means the
			// packet is malformed. Drop it rather than underflow the payload size.
			SocketAddress real_remote_addr;
			const size_t addrlength =
				UnpackAddressFromNAT(buf_, received, &real_remote_addr);
			if (addrlength == 0 || addrlength > received) {
				LOG(LS_ERROR) << "Dropping packet with malformed NAT address header";
				return 0;  // Tell the caller we didn't read anything
			}

			// Make sure this packet should be delivered before returning it.
			if (connected_ && !(real_remote_addr == remote_addr_)) {
				LOG(LS_ERROR) << "Dropping packet from unknown remote address: "
					<< real_remote_addr.ToString();
				return 0;  // Tell the caller we didn't read anything
			}

			// Never copy more than the caller's buffer holds (datagram-style
			// truncation).
			size_t payload_size = received - addrlength;
			if (payload_size > size) {
				payload_size = size;
			}
			std::memcpy(data, buf_ + addrlength, payload_size);
			if (out_addr)
				*out_addr = real_remote_addr;
			return static_cast<int>(payload_size);
		}

		virtual int Close() {
			int result = 0;
			if (socket_) {
				result = socket_->Close();
				if (result >= 0) {
					connected_ = false;
					remote_addr_ = SocketAddress();
					delete socket_;
					socket_ = NULL;
				}
			}
			return result;
		}

		// The remaining AsyncSocket calls are forwarded to the internal socket.
		// Before Bind() (or after Close()) there is no internal socket, so they
		// fail instead of dereferencing NULL.
		virtual int Listen(int backlog) {
			return socket_ ? socket_->Listen(backlog) : -1;
		}
		virtual AsyncSocket* Accept(SocketAddress *paddr) {
			if (!socket_) {
				return NULL;
			}
			return socket_->Accept(paddr);
		}
		virtual int GetError() const {
			return socket_ ? socket_->GetError() : error_;
		}
		virtual void SetError(int error) {
			if (socket_) {
				socket_->SetError(error);
			} else {
				error_ = error;
			}
		}
		virtual ConnState GetState() const {
			return connected_ ? CS_CONNECTED : CS_CLOSED;
		}
		virtual int EstimateMTU(uint16* mtu) {
			return socket_ ? socket_->EstimateMTU(mtu) : -1;
		}
		virtual int GetOption(Option opt, int* value) {
			return socket_ ? socket_->GetOption(opt, value) : -1;
		}
		virtual int SetOption(Option opt, int value) {
			return socket_ ? socket_->SetOption(opt, value) : -1;
		}

		void OnConnectEvent(AsyncSocket* socket) {
			// If we're NATed, we need to send a request with the real addr to use.
			ASSERT(socket == socket_);
			if (server_addr_.IsNil()) {
				connected_ = true;
				SignalConnectEvent(this);
			} else {
				SendConnectRequest();
			}
		}
		void OnReadEvent(AsyncSocket* socket) {
			// If we're NATed, we need to process the connect reply.
			ASSERT(socket == socket_);
			if (type_ == SOCK_STREAM && !server_addr_.IsNil() && !connected_) {
				HandleConnectReply();
			} else {
				SignalReadEvent(this);
			}
		}
		void OnWriteEvent(AsyncSocket* socket) {
			ASSERT(socket == socket_);
			SignalWriteEvent(this);
		}
		void OnCloseEvent(AsyncSocket* socket, int error) {
			ASSERT(socket == socket_);
			SignalCloseEvent(this, error);
		}

	private:
		// Not copyable: the destructor deletes socket_ and buf_, so a copy would
		// delete them twice. Declared but intentionally not defined.
		NATSocket(const NATSocket&);
		NATSocket& operator=(const NATSocket&);

		// Makes sure the buffer is at least the given size. The new buffer is
		// allocated before the old one is released, so buf_ never dangles.
		void Grow(size_t new_size) {
			if (size_ < new_size) {
				char* new_buf = new char[new_size];
				delete[] buf_;
				buf_ = new_buf;
				size_ = new_size;
			}
		}

		// Sends the destination address to the server to tell it to connect.
		void SendConnectRequest() {
			char buf[256];
			size_t length = PackAddressForNAT(buf, ARRAY_SIZE(buf), remote_addr_);
			socket_->Send(buf, length);
		}

		// Handles the byte sent back from the server and fires the appropriate
		// event: 0 means connected, and any other value is passed on as the close
		// error.
		void HandleConnectReply() {
			char code = 0;
			int result = socket_->Recv(&code, sizeof(code));
			if (result <= 0) {
				// No reply byte was read. Wait for the next read event rather than
				// act on a byte that was never received.
				return;
			}
			if (code == 0) {
				// Mark the socket connected. Later read events are then delivered as
				// data instead of being parsed as another connect reply, and Send()'s
				// connected_ assertion holds.
				connected_ = true;
				SignalConnectEvent(this);
			} else {
				// NOTE(review): Close() deletes socket_ while this runs from inside
				// that socket's SignalReadEvent. Confirm the sigslot emit loop
				// tolerates the emitter being destroyed.
				Close();
				SignalCloseEvent(this, code);
			}
		}

		NATInternalSocketFactory* sf_;
		int family_;
		int type_;
		bool async_;  // not read anywhere in this class
		bool connected_;
		SocketAddress remote_addr_;
		SocketAddress server_addr_;  // address of the NAT server
		AsyncSocket* socket_;
		char* buf_;
		size_t size_;
		int error_;  // error reported by GetError() while there is no socket_
	};

	// NATSocketFactory
	// Stores the factory used to create the underlying sockets and the NAT
	// server address that CreateInternalSocket reports for every socket.
	NATSocketFactory::NATSocketFactory(SocketFactory* factory,
		const SocketAddress& nat_addr)
		: factory_(factory),
		  nat_addr_(nat_addr) {
	}

	// Family-less overload: defaults to IPv4.
	Socket* NATSocketFactory::CreateSocket(int type) {
		const int default_family = AF_INET;
		return CreateSocket(default_family, type);
	}

	// Returns a new NATSocket that obtains its internal socket from this
	// factory (via CreateInternalSocket) when it is bound.
	Socket* NATSocketFactory::CreateSocket(int family, int type) {
		NATSocket* nat_socket = new NATSocket(this, family, type);
		return nat_socket;
	}

	// Family-less overload: defaults to IPv4.
	AsyncSocket* NATSocketFactory::CreateAsyncSocket(int type) {
		const int default_family = AF_INET;
		return CreateAsyncSocket(default_family, type);
	}

	// Returns a new NATSocket that obtains its internal socket from this
	// factory (via CreateInternalSocket) when it is bound.
	AsyncSocket* NATSocketFactory::CreateAsyncSocket(int family, int type) {
		NATSocket* nat_socket = new NATSocket(this, family, type);
		return nat_socket;
	}

	// Every socket from this factory talks through the same NAT server, so
	// *nat_addr always receives nat_addr_. local_addr is not consulted. The
	// actual socket comes from the wrapped factory.
	AsyncSocket* NATSocketFactory::CreateInternalSocket(int family, int type,
		const SocketAddress& local_addr, SocketAddress* nat_addr) {
			*nat_addr = nat_addr_;
			AsyncSocket* internal_socket = factory_->CreateAsyncSocket(family, type);
			return internal_socket;
	}

	// NATSocketServer
	// Wraps server, which supplies sockets for addresses not behind any
	// translator (see CreateInternalSocket). msg_queue_ starts out NULL.
	NATSocketServer::NATSocketServer(SocketServer* server)
		: server_(server),
		  msg_queue_(NULL) {
	}

	// Returns the top-level translator registered for ext_ip, or NULL.
	NATSocketServer::Translator* NATSocketServer::GetTranslator(
		const SocketAddress& ext_ip) {
			Translator* found = nats_.Get(ext_ip);
			return found;
	}

	// Creates a top-level translator whose external side is ext_ip (using
	// server_ as the external socket factory) and whose internal side is int_ip.
	// Returns NULL if a translator with this external address already exists.
	NATSocketServer::Translator* NATSocketServer::AddTranslator(
		const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
			if (nats_.Get(ext_ip) != NULL) {
				return NULL;
			}

			Translator* created = new Translator(this, type, int_ip, server_, ext_ip);
			return nats_.Add(ext_ip, created);
	}

	void NATSocketServer::RemoveTranslator(
		const SocketAddress& ext_ip) {
			nats_.Remove(ext_ip);
	}

	// Family-less overload: defaults to IPv4.
	Socket* NATSocketServer::CreateSocket(int type) {
		const int default_family = AF_INET;
		return CreateSocket(default_family, type);
	}

	// Returns a new NATSocket that asks this server for its internal socket
	// (via CreateInternalSocket) when it is bound.
	Socket* NATSocketServer::CreateSocket(int family, int type) {
		NATSocket* nat_socket = new NATSocket(this, family, type);
		return nat_socket;
	}

	// Family-less overload: defaults to IPv4.
	AsyncSocket* NATSocketServer::CreateAsyncSocket(int type) {
		const int default_family = AF_INET;
		return CreateAsyncSocket(default_family, type);
	}

	// Returns a new NATSocket that asks this server for its internal socket
	// (via CreateInternalSocket) when it is bound.
	AsyncSocket* NATSocketServer::CreateAsyncSocket(int family, int type) {
		NATSocket* nat_socket = new NATSocket(this, family, type);
		return nat_socket;
	}

	// Chooses where the socket for local_addr comes from. If a translator
	// (at any nesting depth) lists local_addr as a client, the socket is created
	// on that translator's private network. *nat_addr is then set to the NAT's
	// internal TCP address for SOCK_STREAM, or its internal address otherwise.
	// If no translator matches, the socket comes from server_ and *nat_addr is
	// left untouched.
	AsyncSocket* NATSocketServer::CreateInternalSocket(int family, int type,
		const SocketAddress& local_addr, SocketAddress* nat_addr) {
			Translator* owner = nats_.FindClient(local_addr);
			if (!owner) {
				return server_->CreateAsyncSocket(family, type);
			}

			AsyncSocket* internal_socket =
				owner->internal_factory()->CreateAsyncSocket(family, type);
			if (type == SOCK_STREAM) {
				*nat_addr = owner->internal_tcp_address();
			} else {
				*nat_addr = owner->internal_address();
			}
			return internal_socket;
	}

	// NATSocketServer::Translator
	// Builds a private virtual network driven by the same message queue as the
	// owning NATSocketServer. It also builds a NATServer on that network which
	// listens on int_ip and bridges to the external network through ext_factory
	// at ext_ip.
	NATSocketServer::Translator::Translator(
		NATSocketServer* server, NATType type, const SocketAddress& int_ip,
		SocketFactory* ext_factory, const SocketAddress& ext_ip)
		: server_(server) {
			VirtualSocketServer* private_network = new VirtualSocketServer(server_);
			private_network->SetMessageQueue(server_->queue());
			internal_factory_.reset(private_network);

			NATServer* bridge =
				new NATServer(type, private_network, int_ip, ext_factory, ext_ip);
			nat_server_.reset(bridge);
	}


	// Returns the nested translator registered for ext_ip, or NULL.
	NATSocketServer::Translator* NATSocketServer::Translator::GetTranslator(
		const SocketAddress& ext_ip) {
			Translator* found = nats_.Get(ext_ip);
			return found;
	}

	// Nests a new translator behind this one. Its external address ext_ip is
	// registered as a client of this NAT. Returns NULL if a nested translator
	// with this external address already exists.
	NATSocketServer::Translator* NATSocketServer::Translator::AddTranslator(
		const SocketAddress& ext_ip, const SocketAddress& int_ip, NATType type) {
			if (nats_.Get(ext_ip) != NULL) {
				return NULL;
			}

			AddClient(ext_ip);
			Translator* nested = new Translator(server_, type, int_ip, server_, ext_ip);
			return nats_.Add(ext_ip, nested);
	}
	void NATSocketServer::Translator::RemoveTranslator(
		const SocketAddress& ext_ip) {
			nats_.Remove(ext_ip);
			RemoveClient(ext_ip);
	}

	// Registers int_ip as a client behind this NAT. Returns false, without
	// changing anything, if it is already registered.
	bool NATSocketServer::Translator::AddClient(
		const SocketAddress& int_ip) {
			// std::set::insert reports in .second whether the element was new.
			return clients_.insert(int_ip).second;
	}

	// Unregisters int_ip as a client of this NAT. Removing an unknown address
	// is a no-op.
	void NATSocketServer::Translator::RemoveClient(
		const SocketAddress& int_ip) {
			clients_.erase(int_ip);
	}

	// Returns this translator if int_ip is one of its clients. Otherwise it
	// searches the nested translators and returns NULL if none owns int_ip.
	NATSocketServer::Translator* NATSocketServer::Translator::FindClient(
		const SocketAddress& int_ip) {
			if (clients_.find(int_ip) != clients_.end()) {
				return this;
			}
			return nats_.FindClient(int_ip);
	}

	// NATSocketServer::TranslatorMap
	// The map owns its Translator values and deletes every one on destruction.
	NATSocketServer::TranslatorMap::~TranslatorMap() {
		TranslatorMap::iterator it = begin();
		while (it != end()) {
			delete it->second;
			++it;
		}
	}

	// Looks up the translator stored under ext_ip; NULL when absent.
	NATSocketServer::Translator* NATSocketServer::TranslatorMap::Get(
		const SocketAddress& ext_ip) {
			TranslatorMap::iterator match = find(ext_ip);
			if (match == end()) {
				return NULL;
			}
			return match->second;
	}

	// Stores nat under ext_ip, overwriting any previous entry without deleting
	// it; the AddTranslator callers in this file check Get() first. Returns nat.
	NATSocketServer::Translator* NATSocketServer::TranslatorMap::Add(
		const SocketAddress& ext_ip, Translator* nat) {
			Translator*& slot = operator[](ext_ip);
			slot = nat;
			return nat;
	}

	void NATSocketServer::TranslatorMap::Remove(
		const SocketAddress& ext_ip) {
			TranslatorMap::iterator it = find(ext_ip);
			if (it != end()) {
				delete it->second;
				erase(it);
			}
	}

	// Asks each stored translator in map order (each one recursing into its own
	// nested translators) for int_ip. Returns the first hit, or NULL.
	NATSocketServer::Translator* NATSocketServer::TranslatorMap::FindClient(
		const SocketAddress& int_ip) {
			for (TranslatorMap::iterator it = begin(); it != end(); ++it) {
				Translator* owner = it->second->FindClient(int_ip);
				if (owner) {
					return owner;
				}
			}
			return NULL;
	}

}  // namespace base
